"""
Phase 7: Reviewer Response — Robustness and Clarifications
==========================================================
Addresses five major and three secondary referee comments:

Part 1: Structural balance robustness (IMF gap, Hamilton filter, simple controls)
Part 2: r-g claims tightening (homogeneous yields, formal equality test)
Part 3: Fixed projections (truncate 2040, Bohn reaction, uncertainty bands)
Part 4: OADR units and marginal effects
Part 5: Secondary fixes (same-sample PB vs SB, Bohn β interpretation)

Input:  fiscal_dominance/data/processed/fiscal_panel.csv
        multilateral/followup/data/processed/full_panel.csv (future demographics)
Output: fiscal_dominance/output/tables/phase7_*.md (plus phase7_projections_fixed.csv)
"""

import sys
import pandas as pd
import numpy as np
from pathlib import Path

# Hard-coded absolute project root; every other path below is derived from it.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows")
# Root of the fiscal_dominance sub-project.
FD_DIR = PROJECT_DIR / "fiscal_dominance"
# Root of the multilateral sub-project (Part 3 reads future demographics from under it).
MULTILATERAL_DIR = PROJECT_DIR / "multilateral"
# Processed-data directory of the fiscal_dominance sub-project.
PROCESSED_DIR = FD_DIR / "data" / "processed"
# Directory into which the part* functions write their phase7_* tables.
TABLE_DIR = FD_DIR / "output" / "tables"

# Put multilateral/src at the front of sys.path so `model.PanelGLS` can be imported.
sys.path.insert(0, str(PROJECT_DIR / "multilateral" / "src"))
from model import PanelGLS


def fit_and_report(y, X, entity_ids, time_ids, feature_names, label):
    """Fit a PanelGLS model, print a labelled report, and tabulate the fit.

    Returns a ``(model, result_df)`` tuple: the fitted PanelGLS instance and
    the frame produced by ``model.to_dataframe`` with the label and the fit
    statistics (n_obs, n_countries, r_squared, rho) attached as columns.
    """
    estimator = PanelGLS()
    estimator.fit(y, X, entity_ids, time_ids)

    # Banner with headline fit statistics, followed by the model's own summary.
    rule = '=' * 70
    print(f"\n{rule}")
    print(f"  {label}")
    print(f"  N={estimator.n_obs:,}, {estimator.n_countries} countries, "
          f"R²={estimator.r_squared:.4f}, rho={estimator.rho:.3f}")
    print(f"{rule}")
    estimator.summary(feature_names=feature_names)

    # Coefficient table tagged with the specification label and fit statistics.
    table = estimator.to_dataframe(feature_names=feature_names)
    table['model'] = label
    for attr in ('n_obs', 'n_countries', 'r_squared', 'rho'):
        table[attr] = getattr(estimator, attr)
    return estimator, table


def stars(p):
    """Return the significance marker for p-value *p*.

    '***' below 0.01, '**' below 0.05, '*' below 0.1, otherwise ''.
    """
    # Thresholds are checked from strictest to loosest; first match wins.
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def fmt_coef(coef, se, p):
    """Render a coefficient as ``coef<stars> (se)``, four decimals each."""
    marker = stars(p)
    return "{:.4f}{} ({:.4f})".format(coef, marker, se)


# =====================================================================
# PART 1: Structural Balance Robustness
# =====================================================================
def _hamilton_gap_frame(df, horizon=2, n_lags=4):
    """Compute per-country Hamilton (2018) regression gaps.

    For every country, log(gdp_pc_ppp) ``horizon`` rows ahead is regressed by
    OLS on a constant and the ``n_lags`` most recent values; the residual
    × 100 is the gap.

    The residual is the cyclical component of the *target* observation (the
    one ``horizon`` rows ahead), so it is dated with that observation's year,
    not with the year of the most recent regressor.

    Returns a DataFrame with columns ``iso3``, ``year``, ``hamilton_gap``
    (empty, but with those columns, when no country has enough data).
    """
    frames = []
    for iso3, gdf in df.groupby('iso3'):
        gdf = gdf.sort_values('year').copy()
        if gdf['gdp_pc_ppp'].notna().sum() < 8:
            continue
        gdf['log_gdp'] = np.log(gdf['gdp_pc_ppp'].clip(lower=1))
        # Dependent variable and the year it belongs to (positional lead).
        gdf['log_gdp_lead'] = gdf['log_gdp'].shift(-horizon)
        gdf['target_year'] = gdf['year'].shift(-horizon)
        lag_cols = [f'log_gdp_lag{lag}' for lag in range(n_lags)]
        for lag, col in enumerate(lag_cols):
            gdf[col] = gdf['log_gdp'].shift(lag)

        valid = gdf.dropna(subset=['log_gdp_lead', 'target_year'] + lag_cols)
        if len(valid) < 6:
            continue

        X_ham = np.column_stack(
            [np.ones(len(valid))] + [valid[c].values for c in lag_cols]
        )
        y_ham = valid['log_gdp_lead'].values
        coefs = np.linalg.lstsq(X_ham, y_ham, rcond=None)[0]
        resid = y_ham - X_ham @ coefs
        frames.append(pd.DataFrame({
            'iso3': iso3,
            'year': valid['target_year'].astype(int).values,
            'hamilton_gap': resid * 100,
        }))

    if not frames:
        return pd.DataFrame(columns=['iso3', 'year', 'hamilton_gap'])
    return pd.concat(frames, ignore_index=True)


def _run_bohn_specs(df, tag, regressors, gap_measure, results, min_obs=50):
    """Fit the primary- and structural-balance Bohn regressions for one gap measure.

    Regressors absent from ``df`` are dropped. Each specification with at
    least ``min_obs`` complete rows is fitted via ``fit_and_report`` and its
    ``debt_lag`` coefficient is appended to ``results`` (mutated in place).
    """
    regressors = [v for v in regressors if v in df.columns]
    if 'debt_lag' not in regressors:
        # Without lagged debt there is no Bohn coefficient to report.
        print(f"  Skipping {gap_measure}: 'debt_lag' not in data")
        return
    idx = regressors.index('debt_lag')
    specs = [(f'Primary Balance ({tag})', 'primary_bal_gdp'),
             (f'Structural Balance ({tag})', 'structural_bal_gdp')]
    for dep_label, dep_col in specs:
        est = df.dropna(subset=[dep_col] + regressors).copy()
        if len(est) < min_obs:
            continue
        m, _ = fit_and_report(est[dep_col].values, est[regressors].values,
                              est['iso3'].values, est['year'].values,
                              regressors, dep_label)
        results.append({
            'Specification': dep_label,
            'Gap Measure': gap_measure,
            'Bohn β': m.beta[idx],
            'SE': m.se[idx],
            'p-value': m.pvalues[idx],
            'N': m.n_obs,
            'Countries': m.n_countries,
            'R²': m.r_squared,
        })


def part1_structural_robustness(df):
    """Test the structural Bohn reversal with alternative output gap measures.

    Runs primary- and structural-balance Bohn regressions using (a) the IMF
    WEO output gap, (b) a Hamilton (2018) regression gap built here from
    ``gdp_pc_ppp``, (c) lagged growth plus inflation, and (d) the HP gap, then
    writes ``phase7_structural_robustness.md`` to ``TABLE_DIR``.

    Returns ``(df, results_df)``: a copy of the input panel extended with the
    ``hamilton_gap`` and ``rgdp_growth_lag`` columns, and one row per fitted
    specification holding its ``debt_lag`` coefficient.
    """
    print("\n" + "=" * 70)
    print("  PART 1: Structural Balance Robustness")
    print("=" * 70)

    results = []

    # --- 1a: IMF WEO output gap (27 countries) ---
    print("\n--- 1a: IMF WEO output gap ---")
    if 'output_gap' in df.columns:
        imf_countries = df.loc[df['output_gap'].notna(), 'iso3'].unique()
        print(f"  IMF output gap available: {len(imf_countries)} countries")
    _run_bohn_specs(df, 'IMF gap', ['debt_lag', 'output_gap', 'govt_exp_gap'],
                    'IMF WEO', results)

    # --- 1b: Hamilton (2018) filter ---
    print("\n--- 1b: Hamilton (2018) filter ---")
    ham_df = _hamilton_gap_frame(df)
    print(f"  Hamilton gap computed: {len(ham_df)} obs, {ham_df['iso3'].nunique()} countries")
    # Drop any stale column so the merge cannot produce _x/_y suffixes.
    df = df.drop(columns='hamilton_gap', errors='ignore')
    if ham_df.empty:
        df = df.assign(hamilton_gap=np.nan)
    else:
        df = df.merge(ham_df, on=['iso3', 'year'], how='left')
    _run_bohn_specs(df, 'Hamilton', ['debt_lag', 'hamilton_gap', 'govt_exp_gap'],
                    'Hamilton (2018)', results)

    # --- 1c: Simple controls (no gap) ---
    print("\n--- 1c: Simple controls (growth + inflation) ---")
    # Positional lag within country: relies on rows being in year order.
    df['rgdp_growth_lag'] = df.groupby('iso3')['rgdp_growth'].shift(1)
    _run_bohn_specs(df, 'growth+infl', ['debt_lag', 'rgdp_growth_lag', 'inflation'],
                    'Growth + Inflation', results)

    # --- 1d: HP gap baseline (for comparison) ---
    print("\n--- 1d: HP gap baseline (for comparison) ---")
    _run_bohn_specs(df, 'HP gap', ['debt_lag', 'output_gap_hp', 'govt_exp_gap'],
                    'HP filter (λ=6.25)', results)

    # --- Write output ---
    result_columns = ['Specification', 'Gap Measure', 'Bohn β', 'SE',
                      'p-value', 'N', 'Countries', 'R²']
    results_df = pd.DataFrame(results, columns=result_columns)
    print(f"\n{'=' * 70}")
    print("  STRUCTURAL ROBUSTNESS SUMMARY")
    print("=" * 70)
    print(results_df.to_string(index=False, float_format='%.4f'))

    # Write markdown table
    md_lines = [
        "# Structural Balance Robustness: Alternative Output Gap Measures",
        "",
        "Tests whether the structural Bohn reversal (β = -0.011, p = 0.005) is robust",
        "to alternative cyclical adjustment methods.",
        "",
        "| Dependent Variable | Gap Measure | Bohn β | SE | p-value | N | Countries | R² |",
        "|:--|:--|--:|--:|--:|--:|--:|--:|",
    ]
    for row in results:
        sig = stars(row['p-value'])
        md_lines.append(
            f"| {row['Specification']} | {row['Gap Measure']} | "
            f"{row['Bohn β']:.4f}{sig} | {row['SE']:.4f} | {row['p-value']:.4f} | "
            f"{row['N']:,} | {row['Countries']} | {row['R²']:.3f} |"
        )

    # Interpret
    struct_rows = [r for r in results if 'Structural' in r['Specification']]
    n_struct = len(struct_rows)
    n_negative = int(sum(r['Bohn β'] < 0 for r in struct_rows))
    n_sig = int(sum(r['p-value'] < 0.05 for r in struct_rows))

    md_lines.extend([
        "",
        "## Interpretation",
        "",
        f"- Of {n_struct} structural balance specifications, "
        f"{n_negative} have negative Bohn β and {n_sig} are significant at 5%.",
    ])
    if n_struct == 0:
        # Nothing was estimated, so no robustness claim can be made.
        md_lines.append("- No structural balance specification could be estimated.")
    elif n_negative == n_struct:
        md_lines.append("- The structural reversal is **robust** to alternative gap measures.")
    elif n_negative > n_struct // 2:
        md_lines.append("- The structural reversal is **largely robust** across gap measures.")
    else:
        md_lines.append("- The structural reversal is **sensitive** to gap construction.")

    md_lines.extend([
        "",
        "**Note**: PanelGLS includes entity (country) fixed effects via within-transformation.",
        "The within-R² reported captures variation explained after demeaning.",
    ])

    outfile = TABLE_DIR / "phase7_structural_robustness.md"
    outfile.parent.mkdir(parents=True, exist_ok=True)
    # The table contains non-ASCII symbols (β, R², λ); do not rely on the locale encoding.
    outfile.write_text("\n".join(md_lines), encoding="utf-8")
    print(f"\n  Saved: {outfile}")

    return df, results_df


# =====================================================================
# PART 2: r-g Claims Tightening (Homogeneous Yields)
# =====================================================================
def part2_rg_homogeneous(df):
    """Run Z → rate, Z → growth, Z → r-g on the homogeneous bond-yield sample.

    All three regressions use the identical set of rows: those where the
    10-year yield, real GDP growth, the demographic factors and the available
    controls are all non-missing. An approximate t-test of
    β_rate − β_growth on Z_1 is reported next to the direct r-g regression;
    its SE ignores the covariance between the two estimates, so the direct
    regression is the authoritative test.

    Side effects: adds an ``rg_bond`` column to ``df`` in place and writes
    ``phase7_rg_homogeneous.md`` to ``TABLE_DIR``. Returns ``None``.
    """
    from scipy import stats

    print("\n" + "=" * 70)
    print("  PART 2: r-g Homogeneous Yields (23-country sample)")
    print("=" * 70)

    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = ['fiscal_bal_gdp', 'kaopen', 'log_rel_opw', 'nfa_gdp_lag']
    controls = [c for c in controls if c in df.columns]

    # Identify 23-country bond yield sample
    bond_countries = df.loc[df['govt_bond_10y'].notna(), 'iso3'].unique()
    print(f"  Countries with 10y bond yields: {len(bond_countries)}")

    # Compute r-g using bond yields
    df['rg_bond'] = df['govt_bond_10y'] - df['rgdp_growth']

    # Common sample: all three dependent variables available
    common_vars = demo_vars + controls
    est_base = df.dropna(subset=['govt_bond_10y', 'rgdp_growth', 'rg_bond'] + common_vars).copy()
    print(f"  Common sample: {len(est_base)} obs, {est_base['iso3'].nunique()} countries")
    if est_base.empty:
        # Nothing to estimate; fail soft with a message instead of inside the estimator.
        print("  No observations in common sample; skipping Part 2.")
        return

    idx_z1 = common_vars.index('Z_1')
    results = []
    models = {}

    for dep_label, dep_col in [
        ('10y Bond Yield', 'govt_bond_10y'),
        ('Real GDP Growth', 'rgdp_growth'),
        ('r-g (bond yield)', 'rg_bond'),
    ]:
        m, _ = fit_and_report(
            est_base[dep_col].values, est_base[common_vars].values,
            est_base['iso3'].values, est_base['year'].values,
            common_vars, f"Homogeneous: Z -> {dep_label}"
        )
        models[dep_label] = m
        results.append({
            'Dependent Variable': dep_label,
            'Z₁ Coeff': m.beta[idx_z1],
            'Z₁ SE': m.se[idx_z1],
            'Z₁ p': m.pvalues[idx_z1],
            'N': m.n_obs,
            'Countries': m.n_countries,
            'R²': m.r_squared,
        })

    # Formal equality test
    # The r-g regression directly tests H₀: β_rate = β_growth
    # because r-g = rate - growth, so β_{r-g} = β_rate - β_growth
    # Also compute approximate SE for β_rate - β_growth
    m_rate = models['10y Bond Yield']
    m_growth = models['Real GDP Growth']
    m_rg = models['r-g (bond yield)']

    beta_diff = m_rate.beta[idx_z1] - m_growth.beta[idx_z1]
    se_diff = np.sqrt(m_rate.se[idx_z1]**2 + m_growth.se[idx_z1]**2)
    t_diff = beta_diff / se_diff
    dof = min(m_rate.n_obs, m_growth.n_obs) - len(common_vars) - 1
    # Survival function keeps precision in the tail, where 1 - cdf rounds to 0.
    p_diff = 2 * stats.t.sf(abs(t_diff), df=dof)

    print(f"\n{'=' * 70}")
    print("  FORMAL EQUALITY TEST: β_rate = β_growth")
    print(f"{'=' * 70}")
    print(f"  β_rate (Z₁) = {m_rate.beta[idx_z1]:.4f} (SE={m_rate.se[idx_z1]:.4f})")
    print(f"  β_growth (Z₁) = {m_growth.beta[idx_z1]:.4f} (SE={m_growth.se[idx_z1]:.4f})")
    print(f"  β_rate - β_growth = {beta_diff:.4f} (SE≈{se_diff:.4f})")
    print(f"  t = {t_diff:.3f}, p = {p_diff:.4f}")
    print(f"  Direct r-g test: Z₁ = {m_rg.beta[idx_z1]:.4f} (p={m_rg.pvalues[idx_z1]:.4f})")

    # Write markdown
    md_lines = [
        "# Homogeneous Yield r-g Analysis (23-Country Bond Yield Sample)",
        "",
        "Addresses concern that r-g \"null\" depends on heterogeneous rate measures.",
        "All three regressions estimated on identical sample of countries with 10-year government bond yields.",
        "",
        "## Panel Results (identical sample)",
        "",
        "| Dependent Variable | Z₁ Coefficient | SE | p-value | N | Countries | R² |",
        "|:--|--:|--:|--:|--:|--:|--:|",
    ]
    for row in results:
        sig = stars(row['Z₁ p'])
        md_lines.append(
            f"| {row['Dependent Variable']} | {row['Z₁ Coeff']:.2f}{sig} | "
            f"{row['Z₁ SE']:.2f} | {row['Z₁ p']:.4f} | {row['N']:,} | "
            f"{row['Countries']} | {row['R²']:.3f} |"
        )

    md_lines.extend([
        "",
        "## Formal Equality Test",
        "",
        "The r-g regression constitutes a direct test of H₀: β_rate = β_growth.",
        "",
        f"- β_rate (Z₁) = {m_rate.beta[idx_z1]:.2f} (SE = {m_rate.se[idx_z1]:.2f})",
        f"- β_growth (Z₁) = {m_growth.beta[idx_z1]:.2f} (SE = {m_growth.se[idx_z1]:.2f})",
        f"- β_rate − β_growth = {beta_diff:.2f} (approx. SE = {se_diff:.2f})",
        f"- Approximate t = {t_diff:.3f}, p = {p_diff:.4f}",
        f"- Direct r-g regression: Z₁ = {m_rg.beta[idx_z1]:.2f} (p = {m_rg.pvalues[idx_z1]:.4f})",
        "",
        "## Interpretation",
        "",
    ])
    if m_rg.pvalues[idx_z1] > 0.1:
        md_lines.append(
            "We do not find robust evidence that demographics move r-g, even on a "
            "homogeneous sample of 23 countries with 10-year government bond yields. "
            "The demographic effects on interest rates and growth approximately cancel."
        )
    else:
        md_lines.append(
            f"On the homogeneous bond yield sample, demographics {'raise' if m_rg.beta[idx_z1] > 0 else 'lower'} "
            f"r-g (Z₁ = {m_rg.beta[idx_z1]:.2f}, p = {m_rg.pvalues[idx_z1]:.4f})."
        )

    outfile = TABLE_DIR / "phase7_rg_homogeneous.md"
    outfile.parent.mkdir(parents=True, exist_ok=True)
    # The report contains non-ASCII symbols (β, Z₁, −); do not rely on the locale encoding.
    outfile.write_text("\n".join(md_lines), encoding="utf-8")
    print(f"\n  Saved: {outfile}")


# =====================================================================
# PART 3: Fixed Projections
# =====================================================================
def _future_design_rows(fp, iso3, years, demo_vars, ctrl_values):
    """Build the per-year regressor vectors for one country's projection.

    For each year in ``years`` the first matching row of ``fp`` supplies
    ``old_dep`` and the demographic factors. The walk stops at the first year
    that is absent or has any of those values missing, so the returned list
    may be shorter than ``years``.

    Returns a list of ``(x_gap, x_rg)`` pairs, where
    ``x_gap = [1, oadr, oadr², *controls]`` and ``x_rg = [1, *Z, *controls]``.
    """
    country = fp[fp['iso3'] == iso3]
    needed = ['old_dep'] + list(demo_vars)
    rows = []
    for year in years:
        yr_demo = country[country['year'] == year]
        if len(yr_demo) == 0:
            break
        if any(col not in yr_demo.columns or not yr_demo[col].notna().values[0]
               for col in needed):
            break
        oadr = yr_demo['old_dep'].values[0]
        z_vals = [yr_demo[zv].values[0] for zv in demo_vars]
        rows.append(([1.0, oadr, oadr ** 2] + ctrl_values,
                     [1.0] + z_vals + ctrl_values))
    return rows


def part3_projections_fixed(df):
    """Project debt/GDP to 2040 with a Bohn reaction and Monte Carlo bands.

    Steps:
      1. Estimate the fiscal-gap model (expenditure − revenue on old_dep,
         old_dep² and controls) and the r-g model (on Z and controls); each is
         only fitted when at least 200 complete rows exist, otherwise its
         prediction is taken to be 0.
      2. For every simulation country and Bohn β scenario, draw ``N_MC``
         coefficient vectors (independent normals using the reported SEs) and
         iterate debt_{t+1} = debt_t + (r-g)/100 · debt_t + gap, with the gap
         reduced by β · max(0, debt_t − starting debt).
      3. Write per-year median / p10 / p90 to CSV and a markdown summary.

    The starting debt is the latest observed ``govt_debt_gdp`` for the country
    (not necessarily a 2024 observation). Controls are held at the country's
    latest value, falling back to the panel mean when missing.

    Side effects: adds ``fiscal_gap`` and ``old_dep_sq`` columns to ``df`` in
    place, reseeds NumPy's global RNG, reads the future-demographics panel and
    writes two files under ``TABLE_DIR``. Returns ``None``.
    """
    print("\n" + "=" * 70)
    print("  PART 3: Fixed Debt Projections (2024-2040)")
    print("=" * 70)

    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    decomp_controls = ['kaopen', 'log_rel_opw', 'nfa_gdp_lag']
    decomp_controls = [c for c in decomp_controls if c in df.columns]

    # Estimate gap model: fiscal_gap ~ old_dep + old_dep_sq + controls
    df['fiscal_gap'] = df['govt_expenditure_gdp'] - df['govt_revenue_gdp']
    df['old_dep_sq'] = df['old_dep'] ** 2
    gap_vars = ['old_dep', 'old_dep_sq'] + decomp_controls
    est_gap = df.dropna(subset=['fiscal_gap'] + gap_vars).copy()

    m_gap = None
    if len(est_gap) >= 200:
        m_gap = PanelGLS()
        m_gap.fit(est_gap['fiscal_gap'].values, est_gap[gap_vars].values,
                  est_gap['iso3'].values, est_gap['year'].values)
        print(f"  Gap model: N={m_gap.n_obs}, R²={m_gap.r_squared:.4f}")
        print(f"    old_dep: {m_gap.beta[0]:.2f} (SE={m_gap.se[0]:.2f})")
        print(f"    old_dep_sq: {m_gap.beta[1]:.2f} (SE={m_gap.se[1]:.2f})")

    # Estimate r-g model: r_minus_g ~ Z + controls
    rg_vars = demo_vars + decomp_controls
    est_rg = df.dropna(subset=['r_minus_g'] + rg_vars).copy()
    m_rg = None
    if len(est_rg) >= 200:
        m_rg = PanelGLS()
        m_rg.fit(est_rg['r_minus_g'].values, est_rg[rg_vars].values,
                 est_rg['iso3'].values, est_rg['year'].values)
        print(f"  r-g model: N={m_rg.n_obs}, R²={m_rg.r_squared:.4f}")

    # Load future demographics
    fp = pd.read_csv(MULTILATERAL_DIR / "followup" / "data" / "processed" / "full_panel.csv")

    sim_countries = ['JPN', 'ITA', 'USA', 'FRA', 'DEU', 'GBR', 'ESP',
                     'KOR', 'CHN', 'BRA', 'IND', 'POL', 'THA', 'ZAF', 'MEX']
    sim_years = list(range(2024, 2041))  # Truncate at 2040
    bohn_betas = [0.0, 0.005, 0.02]  # No reaction, baseline β, stronger β

    N_MC = 500
    DEBT_FLOOR, DEBT_CAP = 0, 500  # bounds (% of GDP) that stop explosive paths
    np.random.seed(42)

    control_means = {c: df[c].mean() for c in decomp_controls}

    all_sim_rows = []
    for iso3 in sim_countries:
        # Starting conditions
        latest = df[(df['iso3'] == iso3) & df['govt_debt_gdp'].notna()].sort_values('year').tail(1)
        if len(latest) == 0:
            print(f"  Skipping {iso3}: no debt data")
            continue

        debt_2024 = latest['govt_debt_gdp'].values[0]
        ctrl = {}
        for c in decomp_controls:
            val = latest[c].values[0] if c in latest.columns and latest[c].notna().values[0] else control_means.get(c, 0)
            ctrl[c] = val
        ctrl_values = [ctrl[c] for c in decomp_controls]

        # The demographic inputs depend only on (country, year), so look them
        # up once here rather than once per draw and per scenario.
        design_rows = _future_design_rows(fp, iso3, sim_years, demo_vars, ctrl_values)

        for beta_bohn in bohn_betas:
            # Monte Carlo draws of gap and r-g coefficients
            gap_draws = rg_draws = None
            if m_gap is not None:
                gap_draws = np.random.multivariate_normal(
                    np.concatenate([[m_gap.constant], m_gap.beta]),
                    np.diag(np.concatenate([[m_gap.se_constant], m_gap.se])**2),
                    size=N_MC
                )
            if m_rg is not None:
                rg_draws = np.random.multivariate_normal(
                    np.concatenate([[m_rg.constant], m_rg.beta]),
                    np.diag(np.concatenate([[m_rg.se_constant], m_rg.se])**2),
                    size=N_MC
                )

            # Simulate N_MC paths
            debt_paths = np.full((N_MC, len(sim_years)), np.nan)
            for draw in range(N_MC):
                debt_t = debt_2024
                for t, (x_gap, x_rg) in enumerate(design_rows):
                    # Gap prediction with MC draw
                    if gap_draws is not None:
                        gap_pred = np.dot(gap_draws[draw], x_gap)
                    else:
                        gap_pred = 0.0

                    # r-g prediction with MC draw — clip to [-5, 15] to avoid explosive outliers
                    if rg_draws is not None:
                        rg_pred = np.clip(np.dot(rg_draws[draw], x_rg), -5, 15)
                    else:
                        rg_pred = 0.0

                    # Clip gap prediction to reasonable range
                    gap_pred = np.clip(gap_pred, -15, 15)

                    # Policy reaction: adjust gap by Bohn reaction to debt above 2024
                    if beta_bohn > 0:
                        gap_pred = gap_pred - beta_bohn * max(0, debt_t - debt_2024)

                    # Debt dynamics
                    interest_acc = (rg_pred / 100) * debt_t
                    debt_t_new = debt_t + interest_acc + gap_pred
                    debt_t = max(debt_t_new, DEBT_FLOOR)
                    debt_t = min(debt_t, DEBT_CAP)
                    debt_paths[draw, t] = debt_t

            # Compute percentiles
            for t, year in enumerate(sim_years):
                valid = debt_paths[:, t]
                valid = valid[~np.isnan(valid)]
                if len(valid) == 0:
                    continue
                all_sim_rows.append({
                    'iso3': iso3,
                    'year': year,
                    'bohn_beta': beta_bohn,
                    'debt_median': np.median(valid),
                    'debt_p10': np.percentile(valid, 10),
                    'debt_p90': np.percentile(valid, 90),
                    'debt_mean': np.mean(valid),
                    'n_draws': len(valid),
                })

    sim_df = pd.DataFrame(all_sim_rows)
    if len(sim_df) == 0:
        print("  No simulation results!")
        return

    TABLE_DIR.mkdir(parents=True, exist_ok=True)
    sim_df.to_csv(TABLE_DIR / "phase7_projections_fixed.csv", index=False)
    print(f"\n  Saved simulation: {len(sim_df)} rows")

    # Write markdown table
    md_lines = [
        "# Fixed Debt Projections (2024-2040)",
        "",
        "Addresses reviewer concerns: truncated at 2040, includes Bohn policy reaction",
        f"and Monte Carlo uncertainty bands (N={N_MC} draws from coefficient distributions).",
        "",
        "Three scenarios:",
        "- **No reaction** (β=0): No-policy-change mechanical paths",
        "- **Baseline reaction** (β=0.005): Estimated Bohn coefficient — for each 10pp debt above 2024,",
        "  primary balance improves by 0.05pp",
        "- **Strong reaction** (β=0.02): Stronger fiscal rule — for each 10pp above 2024, +0.2pp improvement",
        "",
    ]

    # Summary table for key years
    for beta in bohn_betas:
        label = "No reaction" if beta == 0 else f"β={beta}" + (" (baseline)" if beta == 0.005 else " (strong)")
        md_lines.extend([
            f"## Scenario: {label}",
            "",
            "| Country | 2024 | 2030 | 2035 | 2040 |",
            "|:--|--:|--:|--:|--:|",
        ])
        subset = sim_df[sim_df['bohn_beta'] == beta]
        for iso3 in sim_countries:
            row_data = [iso3]
            for year in [2024, 2030, 2035, 2040]:
                yr = subset[(subset['iso3'] == iso3) & (subset['year'] == year)]
                if len(yr) > 0:
                    med = yr['debt_median'].values[0]
                    p10 = yr['debt_p10'].values[0]
                    p90 = yr['debt_p90'].values[0]
                    if year == 2024:
                        row_data.append(f"{med:.0f}")
                    else:
                        row_data.append(f"{med:.0f} [{p10:.0f}, {p90:.0f}]")
                else:
                    row_data.append("—")
            md_lines.append("| " + " | ".join(row_data) + " |")
        md_lines.append("")

    md_lines.extend([
        "## Notes",
        "",
        f"- Values show median [10th, 90th percentile] from {N_MC} Monte Carlo draws.",
        f"- Debt floored at {DEBT_FLOOR}% and capped at {DEBT_CAP}% of GDP.",
        "- Projections are **no-policy-change mechanical paths** (β=0) or **stylized Bohn reaction** scenarios.",
        "- Demographic inputs from UN medium-variant population projections.",
        "- r-g and fiscal gap estimated from historical panel; country-specific control values held constant.",
    ])

    outfile = TABLE_DIR / "phase7_projections_fixed.md"
    # The report contains non-ASCII symbols (β, —); do not rely on the locale encoding.
    outfile.write_text("\n".join(md_lines), encoding="utf-8")
    print(f"\n  Saved: {outfile}")


# =====================================================================
# PART 4: OADR Units and Marginal Effects
# =====================================================================
def part4_marginal_effects(df):
    """Compute marginal effects of OADR at various levels.

    Fits expenditure/GDP, revenue/GDP and fiscal gap on ``old_dep``,
    ``old_dep²`` and the available controls (each model needs at least 200
    complete rows), then reports
    ME(OADR) = β_lin + 2·β_sq·OADR at fixed evaluation points and the
    integrated effect of 10pp OADR increases. The SE of the marginal effect
    is approximate: it ignores the covariance between the two coefficients.

    Side effects: adds ``old_dep_sq`` and ``fiscal_gap`` columns to ``df`` in
    place and writes ``phase7_marginal_effects.md`` to ``TABLE_DIR``.
    Returns ``None``.
    """
    print("\n" + "=" * 70)
    print("  PART 4: OADR Marginal Effects")
    print("=" * 70)

    decomp_controls = ['debt_lag', 'kaopen', 'log_rel_opw', 'nfa_gdp_lag']
    decomp_controls = [c for c in decomp_controls if c in df.columns]

    df['old_dep_sq'] = df['old_dep'] ** 2
    df['fiscal_gap'] = df['govt_expenditure_gdp'] - df['govt_revenue_gdp']
    oadr_vars = ['old_dep', 'old_dep_sq'] + decomp_controls
    idx_lin = oadr_vars.index('old_dep')
    idx_sq = oadr_vars.index('old_dep_sq')

    # Estimate expenditure, revenue, and gap models
    specs = [
        ('Expenditure/GDP', 'govt_expenditure_gdp'),
        ('Revenue/GDP', 'govt_revenue_gdp'),
        ('Fiscal Gap', 'fiscal_gap'),
    ]
    # Fixed column order of the markdown tables; must match their headers.
    spec_labels = [label for label, _ in specs]
    models = {}
    for dep_label, dep_col in specs:
        est = df.dropna(subset=[dep_col] + oadr_vars).copy()
        if len(est) < 200:
            continue
        m, _ = fit_and_report(
            est[dep_col].values, est[oadr_vars].values,
            est['iso3'].values, est['year'].values,
            oadr_vars, f"OADR Marginal: {dep_label}"
        )
        models[dep_label] = {
            'beta_lin': m.beta[idx_lin],
            'beta_sq': m.beta[idx_sq],
            'se_lin': m.se[idx_lin],
            'se_sq': m.se[idx_sq],
            'p_lin': m.pvalues[idx_lin],
            'p_sq': m.pvalues[idx_sq],
            'N': m.n_obs,
        }

    # Compute marginal effects at evaluation points
    eval_points = [0.05, 0.10, 0.15, 0.20, 0.25, 0.30]
    sample_mean_oadr = df['old_dep'].mean()
    print(f"\n  Sample mean OADR: {sample_mean_oadr:.4f} ({sample_mean_oadr*100:.1f}%)")

    me_rows = []
    for oadr in eval_points:
        row = {'OADR': oadr, 'OADR (%)': f"{oadr*100:.0f}%"}
        for dep_label, params in models.items():
            me = params['beta_lin'] + 2 * params['beta_sq'] * oadr
            # Approximate SE of marginal effect
            se_me = np.sqrt(params['se_lin']**2 + (2 * oadr * params['se_sq'])**2)
            row[f'{dep_label} ME'] = me
            row[f'{dep_label} SE'] = se_me
        me_rows.append(row)

    # Compute integral effects for 10pp increases
    integral_rows = []
    for start in [0.05, 0.10, 0.15, 0.20]:
        end = start + 0.10
        row = {'Transition': f"{start*100:.0f}% → {end*100:.0f}%"}
        for dep_label, params in models.items():
            # Integral of ME(OADR) from start to end
            # = β_lin × (end - start) + β_sq × (end² - start²)
            integral = params['beta_lin'] * (end - start) + params['beta_sq'] * (end**2 - start**2)
            row[f'{dep_label} Effect (pp)'] = integral
        integral_rows.append(row)

    # Write markdown
    md_lines = [
        "# OADR Marginal Effects and Interpretation",
        "",
        f"Sample mean OADR = {sample_mean_oadr:.3f} ({sample_mean_oadr*100:.1f}%).",
        "OADR is measured as a fraction (0.10 = 10%). Coefficients refer to unit changes in this fraction.",
        "",
        "## Model Coefficients",
        "",
        "| Dependent Variable | OADR (linear) | SE | p-value | OADR² | SE | p-value | N |",
        "|:--|--:|--:|--:|--:|--:|--:|--:|",
    ]
    for dep_label, params in models.items():
        md_lines.append(
            f"| {dep_label} | {params['beta_lin']:.1f}{stars(params['p_lin'])} | "
            f"{params['se_lin']:.1f} | {params['p_lin']:.4f} | "
            f"{params['beta_sq']:.1f}{stars(params['p_sq'])} | "
            f"{params['se_sq']:.1f} | {params['p_sq']:.4f} | {params['N']:,} |"
        )

    md_lines.extend([
        "",
        "## Marginal Effects at Representative OADR Levels",
        "",
        "ME(OADR) = β_linear + 2 × β_quadratic × OADR",
        "",
        "| OADR | Expenditure ME | Revenue ME | Fiscal Gap ME |",
        "|:--|--:|--:|--:|",
    ])
    for row in me_rows:
        line = f"| {row['OADR (%)']} |"
        # Always emit one cell per header column so a skipped model cannot
        # shift the remaining values under the wrong heading.
        for dep_label in spec_labels:
            if dep_label in models:
                me = row[f'{dep_label} ME']
                se = row[f'{dep_label} SE']
                line += f" {me:.1f} (±{se:.1f}) |"
            else:
                line += " — |"
        md_lines.append(line)

    md_lines.extend([
        "",
        "## Effect of 10pp OADR Increase (Integral of Marginal Effects)",
        "",
        "| Transition | Expenditure (pp GDP) | Revenue (pp GDP) | Fiscal Gap (pp GDP) |",
        "|:--|--:|--:|--:|",
    ])
    for row in integral_rows:
        line = f"| {row['Transition']} |"
        for dep_label in spec_labels:
            if dep_label in models:
                val = row[f'{dep_label} Effect (pp)']
                line += f" {val:+.1f} |"
            else:
                line += " — |"
        md_lines.append(line)

    md_lines.extend([
        "",
        "## Interpretation for Paper",
        "",
    ])
    # Find sample-mean interpretation
    if 'Expenditure/GDP' in models and 'Revenue/GDP' in models:
        exp_p = models['Expenditure/GDP']
        rev_p = models['Revenue/GDP']
        me_exp_mean = exp_p['beta_lin'] + 2 * exp_p['beta_sq'] * sample_mean_oadr
        me_rev_mean = rev_p['beta_lin'] + 2 * rev_p['beta_sq'] * sample_mean_oadr
        # The sign comes from the format spec, so a negative effect is not
        # rendered as "+-0.12".
        md_lines.append(
            f"At the sample mean OADR ({sample_mean_oadr*100:.1f}%), the marginal effect of a "
            f"1pp OADR increase is {me_exp_mean/100:+.2f}pp on expenditure/GDP and "
            f"{me_rev_mean/100:+.2f}pp on revenue/GDP."
        )
        # 10pp transition from mean
        start = sample_mean_oadr
        end = start + 0.10
        exp_eff = exp_p['beta_lin'] * 0.10 + exp_p['beta_sq'] * (end**2 - start**2)
        rev_eff = rev_p['beta_lin'] * 0.10 + rev_p['beta_sq'] * (end**2 - start**2)
        md_lines.append(
            f"\nA 10pp OADR increase from {start*100:.0f}% to {end*100:.0f}% implies "
            f"{exp_eff:+.1f}pp expenditure/GDP and {rev_eff:+.1f}pp revenue/GDP, "
            f"opening a fiscal gap of {exp_eff - rev_eff:.1f}pp of GDP."
        )

    outfile = TABLE_DIR / "phase7_marginal_effects.md"
    outfile.parent.mkdir(parents=True, exist_ok=True)
    # The report contains non-ASCII symbols (β, ², →, ±); do not rely on the locale encoding.
    outfile.write_text("\n".join(md_lines), encoding="utf-8")
    print(f"\n  Saved: {outfile}")


# =====================================================================
# PART 5: Secondary Fixes
# =====================================================================
def _implied_stable_debt(rg, bohn_beta, pb_auto):
    """Classify the debt dynamics implied by one r-g scenario.

    Law of motion: Δd = (r-g)/100 × d - pb, with pb = β × d + pb_auto,
    i.e. Δd = [(r-g)/100 - β] × d - pb_auto.  The steady state is
    d* = pb_auto / [(r-g)/100 - β], and it is stable iff the bracket
    (``denom``) is negative.

    Args:
        rg: r-g in percentage points.
        bohn_beta: response of the primary balance to lagged debt.
        pb_auto: autonomous primary balance (the regression constant).

    Returns:
        dict with keys
        'status' ('knife_edge' | 'stable' | 'zero' | 'explosive'),
        'denom'  ((r-g)/100 - β),
        'value'  (numeric d* when denom is clearly negative, else None),
        'label' and 'note' (the strings shown in the markdown table).
    """
    denom = rg / 100 - bohn_beta
    if abs(denom) < 0.0001:
        status, value = 'knife_edge', None
        label, note = "Knife-edge", "β exactly equals r-g/100"
    elif denom < 0:
        # β > (r-g)/100: debt is self-correcting around d*.
        value = pb_auto / denom
        if value > 0:
            status = 'stable'
            label, note = f"{value:.0f}%", "Stable equilibrium (β > r-g)"
        else:
            # A non-negative autonomous balance puts d* at or below zero.
            status = 'zero'
            label, note = "0% (self-correcting)", "Debt converges to zero"
    else:
        # β < (r-g)/100: the Bohn reaction is too weak to stabilise debt.
        status, value = 'explosive', None
        label, note = "No equilibrium", "Explosive (β < r-g)"
    return {'status': status, 'denom': denom, 'value': value,
            'label': label, 'note': note}


def part5_secondary(df):
    """Same-sample PB vs SB comparison and Bohn β interpretation.

    5a: fits the Bohn regression (balance on ``debt_lag`` plus whichever of
        ``output_gap_hp`` / ``govt_exp_gap`` exist in ``df``) twice, once
        for ``primary_bal_gdp`` and once for ``structural_bal_gdp``, on
        the rows where every one of those columns is non-missing.
    5b: translates a fixed baseline β into primary-balance responses at
        several debt levels and into implied stable debt ratios for a grid
        of r-g scenarios, using the constant of the 5a primary-balance
        regression as the autonomous primary balance.

    Writes ``phase7_bohn_interpretation.md`` into ``TABLE_DIR`` (UTF-8).
    Returns None.
    """
    print("\n" + "=" * 70)
    print("  PART 5: Secondary Fixes")
    print("=" * 70)

    # Optional controls are used only when present in the panel.
    bohn_controls = [c for c in ('output_gap_hp', 'govt_exp_gap')
                     if c in df.columns]
    bohn_vars = ['debt_lag'] + bohn_controls

    # --- 5a: Same-sample PB vs SB ---
    print("\n--- 5a: Same-sample PB vs SB comparison ---")

    # Keep only rows where BOTH balance measures and all regressors exist,
    # so the two regressions are estimated on an identical sample.
    both_vars = ['primary_bal_gdp', 'structural_bal_gdp'] + bohn_vars
    est = df.dropna(subset=both_vars).copy()
    print(f"  Same sample: {len(est)} obs, {est['iso3'].nunique()} countries")

    debt_idx = bohn_vars.index('debt_lag')
    results_5a = {}
    for dep_label, dep_col in [('Primary Balance', 'primary_bal_gdp'),
                               ('Structural Balance', 'structural_bal_gdp')]:
        m, _ = fit_and_report(
            est[dep_col].values, est[bohn_vars].values,
            est['iso3'].values, est['year'].values,
            bohn_vars, f"Same-Sample: {dep_label}"
        )
        results_5a[dep_label] = {
            'model': m,
            'bohn_beta': m.beta[debt_idx],
            'bohn_se': m.se[debt_idx],
            'bohn_p': m.pvalues[debt_idx],
            'N': m.n_obs,
            'Countries': m.n_countries,
            'R²': m.r_squared,
            'all_coefs': {
                v: {'coef': m.beta[i], 'se': m.se[i], 'p': m.pvalues[i]}
                for i, v in enumerate(bohn_vars)
            },
        }

    # --- 5b: Bohn β debt dynamics interpretation ---
    print("\n--- 5b: Bohn β debt dynamics interpretation ---")

    # NOTE(review): hard-coded baseline, not the same-sample estimate from
    # 5a — confirm it matches the coefficient reported in the paper.
    bohn_beta = 0.005  # Estimated baseline
    debt_levels = [50, 75, 100, 150, 200, 300, 400]
    rg_scenarios = [0, 1, 2, 3]

    pb_response_rows = [
        {'Debt/GDP (%)': d, 'PB response (pp GDP)': bohn_beta * d}
        for d in debt_levels
    ]

    # Autonomous primary balance = constant of the primary-balance regression.
    pb_auto = results_5a['Primary Balance']['model'].constant

    stable_debt_rows = []
    explosive_rgs = []
    for rg in rg_scenarios:
        scenario = _implied_stable_debt(rg, bohn_beta, pb_auto)
        if scenario['status'] == 'explosive':
            explosive_rgs.append(rg)
        stable_debt_rows.append({
            'r-g (pp)': rg,
            'Denominator': f"{scenario['denom']:.4f}",
            'Stable debt/GDP': scenario['label'],
            'Note': scenario['note'],
        })

    # Write markdown
    md_lines = [
        "# Bohn Coefficient Interpretation and Same-Sample Comparison",
        "",
        "## Part A: Same-Sample Primary vs Structural Balance Bohn Test",
        "",
        "Both regressions estimated on identical sample where both measures are available.",
        "",
        "| Variable | Primary Balance | | Structural Balance | |",
        "|:--|--:|--:|--:|--:|",
        "| | Coeff (SE) | p | Coeff (SE) | p |",
    ]

    pb_res = results_5a['Primary Balance']
    sb_res = results_5a['Structural Balance']

    for v in bohn_vars:
        pb_c = pb_res['all_coefs'][v]
        sb_c = sb_res['all_coefs'][v]
        md_lines.append(
            f"| {v} | {pb_c['coef']:.4f} ({pb_c['se']:.4f}) | "
            f"{pb_c['p']:.4f} | "
            f"{sb_c['coef']:.4f} ({sb_c['se']:.4f}) | "
            f"{sb_c['p']:.4f} |"
        )

    md_lines.extend([
        f"| **N** | {pb_res['N']:,} | | {sb_res['N']:,} | |",
        f"| **Countries** | {pb_res['Countries']} | | {sb_res['Countries']} | |",
        f"| **R²** | {pb_res['R²']:.3f} | | {sb_res['R²']:.3f} | |",
    ])

    md_lines.extend([
        "",
        "## Part B: Bohn β in Debt Dynamics Terms",
        "",
        f"Estimated β = {bohn_beta} (baseline primary balance Bohn coefficient).",
        "",
        "### Primary Balance Response to Debt Level",
        "",
        "| Debt/GDP | PB Response (pp GDP) | Interpretation |",
        "|--:|--:|:--|",
    ])
    for row in pb_response_rows:
        d = row['Debt/GDP (%)']
        pb = row['PB response (pp GDP)']
        interpretation = ('Negligible' if pb < 0.5
                          else 'Weak' if pb < 1.0
                          else 'Moderate')
        md_lines.append(f"| {d}% | {pb:.2f} | {interpretation} |")

    # Derived from bohn_beta so the prose cannot drift from the constant.
    weak_response = (
        f"β ≈ {bohn_beta} implies that a 10pp debt increase yields only "
        f"{10 * bohn_beta:.2f}pp primary balance"
    )

    md_lines.extend([
        "",
        f"{weak_response} improvement.",
        "This is extremely weak relative to the adjustment required for stabilization.",
        "",
        "### Implied Stable Debt Levels (at given r-g)",
        "",
        f"Autonomous primary balance (constant) ≈ {pb_auto:.2f}pp GDP.",
        "Stable debt solves: β × d + pb_auto = (r-g)/100 × d.",
        "",
        "| r-g (pp) | Stable Debt/GDP | Note |",
        "|--:|:--|:--|",
    ])
    for row in stable_debt_rows:
        md_lines.append(f"| {row['r-g (pp)']} | {row['Stable debt/GDP']} | {row['Note']} |")

    # The narrative for r-g = 0 reuses the table's classification instead of
    # abs(pb_auto / β), which reported a positive equilibrium even when the
    # autonomous balance is non-negative and debt actually converges to zero.
    rg0 = _implied_stable_debt(0, bohn_beta, pb_auto)
    if rg0['status'] == 'stable':
        rg0_sentence = (
            "At r-g = 0% (as in Japan), debt converges to an equilibrium around "
            f"{rg0['value']:.0f}% of GDP. "
        )
    elif rg0['status'] == 'zero':
        rg0_sentence = (
            "At r-g = 0% (as in Japan), the autonomous primary balance is "
            "non-negative, so debt converges toward zero. "
        )
    else:
        rg0_sentence = (
            "At r-g = 0% (as in Japan), the Bohn reaction does not pin down "
            "a stable debt level. "
        )

    # denom rises with r-g, so every scenario at or above the smallest
    # explosive one is explosive as well.
    if explosive_rgs:
        explosive_sentence = (
            f"At r-g ≥ {min(explosive_rgs)}%, the Bohn reaction is too weak "
            "to offset interest costs, and debt is explosive."
        )
    else:
        explosive_sentence = ""

    takeaway = (
        f"{weak_response} "
        "improvement — extremely weak relative to the adjustment required for stabilization. "
        f"{rg0_sentence}{explosive_sentence}"
    ).rstrip()

    md_lines.extend([
        "",
        "## Key Takeaway",
        "",
        takeaway,
    ])

    # The report contains non-ASCII characters (β, ², ≈, —, ≥); write it as
    # UTF-8 explicitly instead of relying on the locale's default encoding.
    TABLE_DIR.mkdir(parents=True, exist_ok=True)
    outfile = TABLE_DIR / "phase7_bohn_interpretation.md"
    outfile.write_text("\n".join(md_lines), encoding="utf-8")
    print(f"\n  Saved: {outfile}")


# =====================================================================
# MAIN
# =====================================================================
def main():
    """Run every Phase 7 part in order, then list the generated tables.

    Loads ``fiscal_panel.csv`` from ``PROCESSED_DIR``, lets Part 1 return the
    DataFrame that Parts 2-5 then receive, and finally prints the names of
    all ``phase7_*.md`` files found in ``TABLE_DIR``.
    """
    banner = "=" * 70

    print(banner)
    print("PHASE 7: Reviewer Response — Robustness and Clarifications")
    print(banner)

    panel_path = PROCESSED_DIR / "fiscal_panel.csv"
    df = pd.read_csv(panel_path)
    n_obs = len(df)
    n_countries = df['iso3'].nunique()
    print(f"Loaded: {n_obs:,} obs, {n_countries} countries")

    # Part 1 (structural balance robustness) hands back the panel used by
    # the later parts; its results object is not needed here.
    df, _struct_results = part1_structural_robustness(df)

    # Parts 2-5 each take the panel and produce their own output:
    #   2: r-g homogeneous yields     3: fixed projections
    #   4: OADR marginal effects      5: secondary fixes
    remaining_parts = (
        part2_rg_homogeneous,
        part3_projections_fixed,
        part4_marginal_effects,
        part5_secondary,
    )
    for run_part in remaining_parts:
        run_part(df)

    print("\n" + banner)
    print("  PHASE 7 COMPLETE")
    print("  Output files:")
    table_names = [p.name for p in sorted(TABLE_DIR.glob("phase7_*.md"))]
    for name in table_names:
        print(f"    {name}")
    print(banner)


# Run the full Phase 7 pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
